Лабораторная работа 4: Семантическая сегментация с использованием PyTorch¶

Цели работы¶

Цель — разработать и обучить сверточную нейронную сеть для задачи мультиклассовой семантической сегментации изображений на наборе данных SUIM с использованием PyTorch.

Набор данных¶

  1. Данные содержат 8 классов. Маска сегментации имеет вид трехканального изображения с пикселями, значения которых равны либо 0, либо 255, например, (0, 0, 0), (0, 0, 255) и так далее. Помимо этого встречаются и промежуточные значения, отличные от 0 и 255. В рамках данной лабораторной работы предлагается следующее преобразование: значения маски, меньшие 128, нужно установить в 0, а значения, равные или большие 128, установить в 255.

  2. Для упрощения работы рекомендуется объединить следующие классы в один:

  • класс 2 - Aquatic plants and sea-grass
  • класс 3 - Wrecks and ruins
  • класс 5 - Reefs and invertebrates
  • класс 7 - Sea-floor and rocks

Требования¶

  1. Необходимо выполнить и отобразить в Jupyter следующие задачи:

    • Загрузка и проверка данных. Загрузить и предобработать данные с демонстрацией избранных изображений и соответствующих масок, чтобы подтвердить корректность загрузки и соответствие размерностей данных.
    • Реализация архитектуры сети. Разработать архитектуру нейронной сети для сегментации с использованием фреймворка PyTorch.
    • Настройка гиперпараметров обучения. Настроить параметры обучения, такие как функция ошибки, размер сети, скорость обучения и другие параметры.
    • Тестирование модели. После завершения обучения для оценки качества работы необходимо посчитать accuracy, IoU и визуализировать confusion matrix (с нормализацией, normalize='true').
    • Визуализация результатов. После завершения обучения построить и отобразить результаты сегментации на тестовых изображениях, сравнивая с реальными масками сегментации.
  2. Выбор архитектуры:

  • Можно использовать или адаптировать известные архитектуры глубокого обучения.
  • Может быть полезным:
    • уменьшить количество параметров в нейронной сети и размер входного изображения для ускорения сходимости, предотвращения переобучения и ускорения работы нейронной сети.
    • использовать аугментацию данных и взвешенные/специализированные функции ошибки. При аугментации данных необходимо учитывать связь изображений с маской классов.
  • Использовать перенос знаний недопустимо.
In [142]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import confusion_matrix
import seaborn as sea

import matplotlib.pyplot as plt
from torchsummary import summary
from PIL import Image
import numpy as np
import time

Загрузка и проверка корректности данных

In [ ]:
import os
import shutil

from google.colab import drive
drive.mount('/content/drive')

def copy_files_recursive(source_folder, destination_folder):
    """Copy every file under source_folder into destination_folder,
    recreating the relative directory structure as needed.

    Directories are created lazily, only when a file lands in them,
    so empty source directories are not mirrored.
    """
    for current_dir, _subdirs, file_names in os.walk(source_folder):
        for name in file_names:
            src_path = os.path.join(current_dir, name)
            rel_path = os.path.relpath(src_path, source_folder)
            dst_path = os.path.join(destination_folder, rel_path)

            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            shutil.copyfile(src_path, dst_path)
Mounted at /content/drive
In [ ]:
# Copy the SUIM dataset from Google Drive to fast local Colab storage.
remote_root = '/content/drive/MyDrive/SUIM'
root = '/content/SUIM'
copy_files_recursive(remote_root, root)
In [ ]:
# Number of segmentation classes after merging: per the task description,
# SUIM classes 2, 3, 5 and 7 are collapsed into a single "other" class.
number_classes = 5

# Binarized mask colors (0/1 RGB triples) belonging to each merged class.
# "other" collects every color of the four merged SUIM classes.
classes = {
    "background": [(0, 0, 0)],
    "human_divers": [(0, 0, 1)],
    "robots": [(1, 0, 0)],
    "fish_vertebrates": [(1, 1, 0)],
    "other": [(0, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1)]
}

# Display color of each one-hot channel, used when overlaying masks.
color_classes = [[0, 0, 0],
                 [0, 0, 1],
                 [1, 0, 0],
                 [1, 1, 0],
                 [0, 1, 0]]
In [ ]:
class CustomDataset(Dataset):
    """In-memory segmentation dataset holding images and one-hot masks.

    Both inputs are converted once to float tensors; __getitem__ then
    just indexes into them.  Assumes images are (N, H, W, C_img) and
    masks are (N, H, W, C_cls) — channels-last, as produced by the
    loading code.
    """

    def __init__(self, images, masks):
        self.images = torch.tensor(images, dtype = torch.float)
        self.masks = torch.tensor(masks, dtype = torch.float)

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        # One (image, mask) sample pair.
        return self.images[idx], self.masks[idx]


def load_dataset(root_images, root_masks, image_size):
    """Load and preprocess SUIM images and masks from disk.

    root_images / root_masks: directories whose sorted file listings
    correspond pairwise.  image_size: (width, height) tuple passed to
    PIL's resize.  Returns a CustomDataset of images normalized to
    [0, 1] with shape (H, W, 3) and one-hot masks (H, W, 5) with
    channel order background / human_divers / robots /
    fish_vertebrates / other.
    """
    images = []
    for file_name in sorted(os.listdir(root_images)):
        file_path = os.path.join(root_images, file_name)
        if os.path.isfile(file_path):
            with Image.open(file_path) as image:
                # Normalize pixel values to [0, 1].
                images.append(np.array(image.resize(image_size)) / 255)

    # Exact-color classes checked in priority order; any pixel matching
    # none of them falls into the final "other" channel.
    matched_classes = ["background", "human_divers", "robots", "fish_vertebrates"]

    labels = []
    for file_name in sorted(os.listdir(root_masks)):
        file_path = os.path.join(root_masks, file_name)
        if os.path.isfile(file_path):
            with Image.open(file_path) as mask:
                # Binarize: resizing interpolates intermediate values, so
                # threshold at 0.5 (raw value 128) per the assignment.
                resized_mask = np.array(mask.resize(image_size)) / 255
                resized_mask = np.where(resized_mask < 0.5, 0, 1)

                # Vectorized one-hot encoding.  Replaces the original
                # per-pixel Python double loop, which did O(H*W)
                # interpreter-level comparisons per mask.
                # NOTE(review): assumes masks decode to 3 channels — confirm
                # no grayscale/RGBA masks exist in the dataset.
                one_hot = np.zeros(image_size + (number_classes,))
                assigned = np.zeros(image_size, dtype=bool)
                for channel, name in enumerate(matched_classes):
                    color = np.array(classes[name][0])
                    hit = np.all(resized_mask == color, axis=-1) & ~assigned
                    one_hot[hit, channel] = 1
                    assigned |= hit
                # Everything that matched no listed color becomes "other".
                one_hot[~assigned, number_classes - 1] = 1
                labels.append(one_hot)

    return CustomDataset(np.array(images), np.array(labels))
In [ ]:
def dataset_info(dataset):
    print("Размер датасета изображений:", dataset.images.shape)
    print("Размер датасета масок:", dataset.masks.shape)
    print()

    number_pixels = {'Background': np.count_nonzero(dataset.masks[:, :, :, 0] == 1),
                     'Human divers': np.count_nonzero(dataset.masks[:, :, :, 1] == 1),
                     'Robots': np.count_nonzero(dataset.masks[:, :, :, 2] == 1),
                     'Fish and vertebrates': np.count_nonzero(dataset.masks[:, :, :, 3] == 1),
                     'Other': np.count_nonzero(dataset.masks[:, :, :, 4] == 1)}

    for key, value in number_pixels.items():
        print(f'Класс: {key}, Число пикселей: {value}')
In [ ]:
image_size = 80  # images and masks are resized to image_size x image_size
In [ ]:
# Load the training split, resized to image_size x image_size.
root_train_images = "/content/SUIM/train_val/images"
root_train_masks = "/content/SUIM/train_val/masks"
train_dataset = load_dataset(root_train_images, root_train_masks, (image_size, image_size))
In [ ]:
# Load the held-out test split, resized to image_size x image_size.
root_test_images = "/content/SUIM/TEST/images"
root_test_masks = "/content/SUIM/TEST/masks"
test_dataset = load_dataset(root_test_images, root_test_masks, (image_size, image_size))
In [ ]:
# Sanity-check shapes and class balance of both splits.
print('Train dataset:')
dataset_info(train_dataset)
print()
print('Test dataset:')
dataset_info(test_dataset)
Train dataset:
Размер датасета изображений: torch.Size([1525, 80, 80, 3])
Размер датасета масок: torch.Size([1525, 80, 80, 5])

Класс: Background, Число пикселей: 3034338
Класс: Human divers, Число пикселей: 184114
Класс: Robots, Число пикселей: 37740
Класс: Fish and vertebrates, Число пикселей: 767257
Класс: Other, Число пикселей: 5736551

Test dataset:
Размер датасета изображений: torch.Size([110, 80, 80, 3])
Размер датасета масок: torch.Size([110, 80, 80, 5])

Класс: Background, Число пикселей: 282598
Класс: Human divers, Число пикселей: 20661
Класс: Robots, Число пикселей: 4557
Класс: Fish and vertebrates, Число пикселей: 54083
Класс: Other, Число пикселей: 342101
In [97]:
def plot_images_with_masks(dataset, title):
    """Show a 6x3 grid of samples: each image next to the same image
    with its ground-truth mask overlaid in class colors."""
    rows, cols = 6, 3
    fig, axes = plt.subplots(rows, cols * 2, figsize = (15, 15))
    fig.suptitle(title)

    overlay_shape = (image_size, image_size, 3)

    for number in range(rows * cols):
        row, col = divmod(number, cols)
        image, mask = dataset[number]

        axes[row, col * 2].imshow(image, cmap=plt.cm.binary)
        axes[row, col * 2].axis('off')

        # Paint each active one-hot channel with its display color.
        rgb_mask = np.zeros(overlay_shape)
        for k in range(number_classes):
            rgb_mask[mask[:, :, k] > 0] = color_classes[k]

        axes[row, col * 2 + 1].imshow(image, cmap=plt.cm.binary)
        axes[row, col * 2 + 1].imshow(rgb_mask, alpha = 0.35)
        axes[row, col * 2 + 1].axis('off')

    plt.tight_layout()
    plt.show()
In [ ]:
plot_images_with_masks(train_dataset, 'Examples from train dataset')
No description has been provided for this image
In [98]:
plot_images_with_masks(test_dataset, 'Examples from test dataset')
No description has been provided for this image
In [ ]:
# Prefer the GPU when available; everything is moved to this device later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device: {device}')
Device: cuda

Эксперименты¶

Гиперпараметры обучения

In [99]:
# Training hyperparameters.
learning_rate = 0.01
epochs = 80
batch_size = 36

Разбиение датасетов на батчи

In [100]:
# Shuffle only the training split; keep test order fixed for evaluation.
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)
In [101]:
class DoubleConv(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU -> Dropout(0.25), resolution-preserving.

    Despite the name, the second convolution of the classic U-Net
    double-conv block is disabled here.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 'same'),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.25),
        )

    def forward(self, x):
        return self.double_conv(x)


class DownSample(nn.Module):
    """Encoder step: halve spatial size with 2x2 max-pooling, then DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x):
        # Pool first so the convolution runs on the smaller feature map.
        return self.conv(self.max_pool(x))


class UpSample(nn.Module):
    """Decoder step: bilinearly upsample x1, pad it to x2's spatial size,
    concatenate along channels, and apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad symmetrically so odd-sized skip connections still line up.
        height_diff = x2.size()[2] - x1.size()[2]
        width_diff = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [width_diff // 2, width_diff - width_diff // 2,
                        height_diff // 2, height_diff - height_diff // 2])

        merged = torch.cat([x2, x1], dim = 1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution projecting features to per-class score maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    """U-Net-style encoder/decoder for multi-label segmentation.

    forward() takes (N, n_channels, H, W) and returns per-pixel sigmoid
    probabilities of shape (N, n_classes, H, W).
    """

    def __init__(self, n_channels = 3, n_classes = 5):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        # Encoder: channel width doubles while resolution halves.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = DownSample(64, 128)
        self.down2 = DownSample(128, 256)
        self.down3 = DownSample(256, 512)
        self.dropout = nn.Dropout(0.5)
        self.down4 = DownSample(512, 1024)

        # Decoder: each step consumes upsampled features plus a skip.
        self.up1 = UpSample(1024 + 512, 512)
        self.up2 = UpSample(512 + 256, 256)
        self.up3 = UpSample(256 + 128, 128)
        self.up4 = UpSample(128 + 64, 64)

        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        # Dropout both before and after the bottleneck.
        bottleneck = self.down4(self.dropout(skip4))
        features = self.dropout(bottleneck)

        features = self.up1(features, skip4)
        features = self.up2(features, skip3)
        features = self.up3(features, skip2)
        features = self.up4(features, skip1)

        # Sigmoid (not softmax): the model is trained with BCELoss on
        # independent per-class channels.
        return torch.sigmoid(self.outc(features))
In [145]:
net = UNet().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr = learning_rate)
summary(net, (3, image_size, image_size))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 80, 80]           1,792
       BatchNorm2d-2           [-1, 64, 80, 80]             128
              ReLU-3           [-1, 64, 80, 80]               0
           Dropout-4           [-1, 64, 80, 80]               0
        DoubleConv-5           [-1, 64, 80, 80]               0
         MaxPool2d-6           [-1, 64, 40, 40]               0
            Conv2d-7          [-1, 128, 40, 40]          73,856
       BatchNorm2d-8          [-1, 128, 40, 40]             256
              ReLU-9          [-1, 128, 40, 40]               0
          Dropout-10          [-1, 128, 40, 40]               0
       DoubleConv-11          [-1, 128, 40, 40]               0
       DownSample-12          [-1, 128, 40, 40]               0
        MaxPool2d-13          [-1, 128, 20, 20]               0
           Conv2d-14          [-1, 256, 20, 20]         295,168
      BatchNorm2d-15          [-1, 256, 20, 20]             512
             ReLU-16          [-1, 256, 20, 20]               0
          Dropout-17          [-1, 256, 20, 20]               0
       DoubleConv-18          [-1, 256, 20, 20]               0
       DownSample-19          [-1, 256, 20, 20]               0
        MaxPool2d-20          [-1, 256, 10, 10]               0
           Conv2d-21          [-1, 512, 10, 10]       1,180,160
      BatchNorm2d-22          [-1, 512, 10, 10]           1,024
             ReLU-23          [-1, 512, 10, 10]               0
          Dropout-24          [-1, 512, 10, 10]               0
       DoubleConv-25          [-1, 512, 10, 10]               0
       DownSample-26          [-1, 512, 10, 10]               0
          Dropout-27          [-1, 512, 10, 10]               0
        MaxPool2d-28            [-1, 512, 5, 5]               0
           Conv2d-29           [-1, 1024, 5, 5]       4,719,616
      BatchNorm2d-30           [-1, 1024, 5, 5]           2,048
             ReLU-31           [-1, 1024, 5, 5]               0
          Dropout-32           [-1, 1024, 5, 5]               0
       DoubleConv-33           [-1, 1024, 5, 5]               0
       DownSample-34           [-1, 1024, 5, 5]               0
          Dropout-35           [-1, 1024, 5, 5]               0
         Upsample-36         [-1, 1024, 10, 10]               0
           Conv2d-37          [-1, 512, 10, 10]       7,078,400
      BatchNorm2d-38          [-1, 512, 10, 10]           1,024
             ReLU-39          [-1, 512, 10, 10]               0
          Dropout-40          [-1, 512, 10, 10]               0
       DoubleConv-41          [-1, 512, 10, 10]               0
         UpSample-42          [-1, 512, 10, 10]               0
         Upsample-43          [-1, 512, 20, 20]               0
           Conv2d-44          [-1, 256, 20, 20]       1,769,728
      BatchNorm2d-45          [-1, 256, 20, 20]             512
             ReLU-46          [-1, 256, 20, 20]               0
          Dropout-47          [-1, 256, 20, 20]               0
       DoubleConv-48          [-1, 256, 20, 20]               0
         UpSample-49          [-1, 256, 20, 20]               0
         Upsample-50          [-1, 256, 40, 40]               0
           Conv2d-51          [-1, 128, 40, 40]         442,496
      BatchNorm2d-52          [-1, 128, 40, 40]             256
             ReLU-53          [-1, 128, 40, 40]               0
          Dropout-54          [-1, 128, 40, 40]               0
       DoubleConv-55          [-1, 128, 40, 40]               0
         UpSample-56          [-1, 128, 40, 40]               0
         Upsample-57          [-1, 128, 80, 80]               0
           Conv2d-58           [-1, 64, 80, 80]         110,656
      BatchNorm2d-59           [-1, 64, 80, 80]             128
             ReLU-60           [-1, 64, 80, 80]               0
          Dropout-61           [-1, 64, 80, 80]               0
       DoubleConv-62           [-1, 64, 80, 80]               0
         UpSample-63           [-1, 64, 80, 80]               0
           Conv2d-64            [-1, 5, 80, 80]             325
          OutConv-65            [-1, 5, 80, 80]               0
================================================================
Total params: 15,678,085
Trainable params: 15,678,085
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.07
Forward/backward pass size (MB): 82.62
Params size (MB): 59.81
Estimated Total Size (MB): 142.50
----------------------------------------------------------------
In [103]:
# Training loop: one pass over train_loader per epoch, reporting the
# mean batch loss and wall-clock time per epoch.
for epoch in range(epochs):
    loss_list = []
    time_one = time.time()
    for data in train_loader:
        # NHWC -> NCHW for the convolutional network.
        images = data[0].permute(0, 3, 1, 2).to(device)
        labels = data[1].permute(0, 3, 1, 2).to(device)

        outputs = net(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # detach() so the stored losses do not keep each batch's autograd
        # graph alive for the whole epoch (memory leak in the original).
        loss_list.append(loss.detach())

    diff_time = time.time() - time_one
    print(f"Epoch: {epoch + 1}/{epochs}, "
          f"Loss: {torch.stack(loss_list).mean():.4f}, "
          f"Time: {diff_time:.2f} ")
Epoch: 1/80, Loss: 0.3077, Time: 7.48 
Epoch: 2/80, Loss: 0.2391, Time: 7.38 
Epoch: 3/80, Loss: 0.2211, Time: 7.41 
Epoch: 4/80, Loss: 0.2108, Time: 7.45 
Epoch: 5/80, Loss: 0.2105, Time: 7.52 
Epoch: 6/80, Loss: 0.2019, Time: 7.51 
Epoch: 7/80, Loss: 0.2028, Time: 7.57 
Epoch: 8/80, Loss: 0.1961, Time: 7.58 
Epoch: 9/80, Loss: 0.1915, Time: 7.62 
Epoch: 10/80, Loss: 0.1883, Time: 7.64 
Epoch: 11/80, Loss: 0.1857, Time: 7.66 
Epoch: 12/80, Loss: 0.1825, Time: 7.67 
Epoch: 13/80, Loss: 0.1807, Time: 7.72 
Epoch: 14/80, Loss: 0.1779, Time: 7.69 
Epoch: 15/80, Loss: 0.1722, Time: 7.70 
Epoch: 16/80, Loss: 0.1699, Time: 7.76 
Epoch: 17/80, Loss: 0.1643, Time: 7.75 
Epoch: 18/80, Loss: 0.1650, Time: 7.77 
Epoch: 19/80, Loss: 0.1615, Time: 7.78 
Epoch: 20/80, Loss: 0.1578, Time: 7.81 
Epoch: 21/80, Loss: 0.1560, Time: 7.81 
Epoch: 22/80, Loss: 0.1531, Time: 7.82 
Epoch: 23/80, Loss: 0.1508, Time: 7.82 
Epoch: 24/80, Loss: 0.1558, Time: 7.83 
Epoch: 25/80, Loss: 0.1512, Time: 7.84 
Epoch: 26/80, Loss: 0.1482, Time: 7.83 
Epoch: 27/80, Loss: 0.1461, Time: 7.84 
Epoch: 28/80, Loss: 0.1379, Time: 7.85 
Epoch: 29/80, Loss: 0.1349, Time: 7.86 
Epoch: 30/80, Loss: 0.1292, Time: 7.87 
Epoch: 31/80, Loss: 0.1305, Time: 7.87 
Epoch: 32/80, Loss: 0.1313, Time: 7.89 
Epoch: 33/80, Loss: 0.1264, Time: 7.88 
Epoch: 34/80, Loss: 0.1251, Time: 7.89 
Epoch: 35/80, Loss: 0.1265, Time: 7.89 
Epoch: 36/80, Loss: 0.1164, Time: 7.90 
Epoch: 37/80, Loss: 0.1135, Time: 7.89 
Epoch: 38/80, Loss: 0.1118, Time: 7.90 
Epoch: 39/80, Loss: 0.1109, Time: 7.90 
Epoch: 40/80, Loss: 0.1096, Time: 7.90 
Epoch: 41/80, Loss: 0.1039, Time: 7.91 
Epoch: 42/80, Loss: 0.1030, Time: 7.90 
Epoch: 43/80, Loss: 0.1020, Time: 7.90 
Epoch: 44/80, Loss: 0.0963, Time: 7.91 
Epoch: 45/80, Loss: 0.0959, Time: 7.90 
Epoch: 46/80, Loss: 0.0929, Time: 7.91 
Epoch: 47/80, Loss: 0.0976, Time: 7.91 
Epoch: 48/80, Loss: 0.0926, Time: 7.91 
Epoch: 49/80, Loss: 0.0867, Time: 7.92 
Epoch: 50/80, Loss: 0.0908, Time: 7.90 
Epoch: 51/80, Loss: 0.0834, Time: 7.93 
Epoch: 52/80, Loss: 0.0850, Time: 7.91 
Epoch: 53/80, Loss: 0.0901, Time: 7.91 
Epoch: 54/80, Loss: 0.0911, Time: 7.93 
Epoch: 55/80, Loss: 0.0862, Time: 7.92 
Epoch: 56/80, Loss: 0.0811, Time: 7.91 
Epoch: 57/80, Loss: 0.0856, Time: 7.90 
Epoch: 58/80, Loss: 0.0789, Time: 7.92 
Epoch: 59/80, Loss: 0.0743, Time: 7.90 
Epoch: 60/80, Loss: 0.0718, Time: 7.91 
Epoch: 61/80, Loss: 0.0682, Time: 7.92 
Epoch: 62/80, Loss: 0.0650, Time: 7.92 
Epoch: 63/80, Loss: 0.0666, Time: 7.90 
Epoch: 64/80, Loss: 0.0652, Time: 7.91 
Epoch: 65/80, Loss: 0.0720, Time: 7.91 
Epoch: 66/80, Loss: 0.0709, Time: 7.92 
Epoch: 67/80, Loss: 0.0669, Time: 7.95 
Epoch: 68/80, Loss: 0.0641, Time: 8.02 
Epoch: 69/80, Loss: 0.0611, Time: 8.06 
Epoch: 70/80, Loss: 0.0585, Time: 8.04 
Epoch: 71/80, Loss: 0.0589, Time: 8.00 
Epoch: 72/80, Loss: 0.0586, Time: 7.97 
Epoch: 73/80, Loss: 0.0573, Time: 7.93 
Epoch: 74/80, Loss: 0.0557, Time: 7.92 
Epoch: 75/80, Loss: 0.0541, Time: 7.93 
Epoch: 76/80, Loss: 0.0546, Time: 7.92 
Epoch: 77/80, Loss: 0.0548, Time: 7.93 
Epoch: 78/80, Loss: 0.0529, Time: 7.94 
Epoch: 79/80, Loss: 0.0512, Time: 7.94 
Epoch: 80/80, Loss: 0.0521, Time: 7.95 
In [104]:
def IoU_bin(labels, predict):
    """Binary intersection-over-union of two masks.

    The small epsilon in the denominator keeps the result finite
    (approximately 0) when both masks are empty.
    """
    overlap = np.sum(np.logical_and(labels, predict))
    combined = np.sum(np.logical_or(labels, predict))
    return overlap / (0.0001 + combined)


def IoU_accuracy_cm_compute(X):
    """Evaluate the trained `net` on loader X.

    Prints per-class and overall binary accuracy and IoU (computed per
    one-hot channel), then plots a row-normalized confusion matrix built
    from the flattened, integer-recoded channel maps.
    """
    accuracy_list, IoU_list = [], []
    cm_label, cm_pred = [], []
    with torch.no_grad():
        for images, labels in X:
            # NHWC -> NCHW for the network; labels stay on CPU as numpy.
            images = images.permute(0, 3, 1, 2).to(device)
            labels = labels.permute(0, 3, 1, 2).numpy()
            predict = net(images)
            # Threshold the sigmoid outputs at 0.5 to get hard binary maps.
            predict = torch.where(predict < torch.tensor(0.5),
                                            torch.tensor(0),
                                            torch.tensor(1)).cpu().numpy()

            # Recode channel k's binary map to values {0, k+1} so the
            # flattened arrays can be fed to sklearn's confusion_matrix.
            for elem in range(1, 6):
                elem = elem - 1  # effectively iterates elem = 0..4
                labels_5_int = labels[:, elem, :, :]
                predict_5_int = predict[:, elem, :, :]

                labels_5_int = np.where(labels_5_int != 1, 0, elem + 1)
                predict_5_int = np.where(predict_5_int != 1, 0, elem + 1)
                labels[:, elem, :, :] = labels_5_int
                predict[:, elem, :, :] = predict_5_int

            cm_label.append(labels)
            cm_pred.append(predict)

            # Per-batch binary accuracy and IoU for each of the 5 channels.
            temp_acc, temp_iou = [], []
            for k in range(5):
                acc = np.mean(predict[:, k, :, :] == labels[:, k, :, :])
                IoU = IoU_bin(labels[:, k, :, :], predict[:, k, :, :])
                temp_acc.append(acc)
                temp_iou.append(IoU)
            accuracy_list.append(temp_acc)
            IoU_list.append(temp_iou)

    accuracy_list = np.array(accuracy_list)
    IoU_list = np.array(IoU_list)

    # NOTE(review): float16 reduction loses precision; kept to match the
    # originally printed results.
    print(f'Оценка Acc для классов: {np.mean(accuracy_list, axis = 0, dtype = np.float16)}')
    print(f'Оценка IoU для классов: {np.mean(IoU_list, axis = 0, dtype = np.float16)}')

    print(f'Оценка Acc на данных: {np.mean(accuracy_list, dtype = np.float16):.4f}')
    print(f'Оценка IoU на данных: {np.mean(IoU_list, dtype = np.float16):.4f}')

    # Flatten all channels of all samples into one long vector. NOTE(review):
    # the recoded values run 1..5 (0 = channel inactive) while
    # confusion_matrix is given labels 0..4, so value 5 ("other" active) is
    # dropped and the tick names appear shifted by one relative to the
    # recoded values — verify this mapping is intended.
    cm_label = np.concatenate(cm_label, axis = 0)
    cm_pred = np.concatenate(cm_pred, axis = 0)
    cm_label = cm_label.flatten()
    cm_pred = cm_pred.flatten()

    name_class = ['Background', 'Human', 'Robot', 'Fish', 'NewClass']
    cm = confusion_matrix(cm_label, cm_pred, labels = np.arange(5), normalize = 'true')
    sea.heatmap(cm, annot = True, cmap = 'Blues', xticklabels = name_class, yticklabels = name_class)
    plt.xlabel('Предсказанные классы')
    plt.ylabel('Истинные классы')
    plt.title('Confusion matrix')
    plt.show()

IoU_accuracy_cm_compute(test_loader)
Оценка Acc для классов: [0.931  0.983  0.9956 0.9385 0.9053]
Оценка IoU для классов: [0.8345 0.3665 0.2418 0.4949 0.8223]
Оценка Acc на данных: 0.9507
Оценка IoU на данных: 0.5518
No description has been provided for this image
In [140]:
def plot_images_with_masks_test(dataset):
    """Plot image / predicted-mask / true-mask triples for the first 24
    samples of dataset, overlaying masks on the images in class colors."""
    rows, cols = 12, 2
    fig, axes = plt.subplots(rows, cols * 3, figsize = (15, 25))
    fig.suptitle("Predicted vs. True")

    overlay_shape = (image_size, image_size, 3)

    for number in range(rows * cols):
        row, col = divmod(number, cols)
        image, mask = dataset[number]

        axes[row, col * 3].imshow(image, cmap=plt.cm.binary)
        axes[row, col * 3].set_title('Image', fontsize = 10)
        axes[row, col * 3].axis('off')

        # Run the net on this single image (NHWC -> NCHW) and threshold
        # the sigmoid output at 0.5.
        with torch.no_grad():
            batch = image.unsqueeze(0).permute(0, 3, 1, 2).to(device)
            predict_mask = net(batch)
            predict_mask = torch.where(predict_mask < torch.tensor(0.5),
                                       torch.tensor(0),
                                       torch.tensor(1)).permute(0, 2, 3, 1).cpu()

        predicted_rgb = np.zeros(overlay_shape)
        for k in range(number_classes):
            predicted_rgb[predict_mask[0, :, :, k] > 0] = color_classes[k]

        axes[row, col * 3 + 1].imshow(image, cmap=plt.cm.binary)
        axes[row, col * 3 + 1].imshow(predicted_rgb, alpha = 0.35)
        axes[row, col * 3 + 1].set_title('Predicted', fontsize = 10)
        axes[row, col * 3 + 1].axis('off')

        true_rgb = np.zeros(overlay_shape)
        for k in range(number_classes):
            true_rgb[mask[:, :, k] > 0] = color_classes[k]

        axes[row, col * 3 + 2].imshow(image, cmap=plt.cm.binary)
        axes[row, col * 3 + 2].imshow(true_rgb, alpha = 0.35)
        axes[row, col * 3 + 2].set_title('True', fontsize = 10)
        axes[row, col * 3 + 2].axis('off')

    plt.tight_layout()
    plt.show()
In [141]:
plot_images_with_masks_test(test_dataset)
No description has been provided for this image